###############################################################################
# Name: ScCommand.py                                                          #
# Purpose: Enumerate modified, added, deleted files in a list                 #
# Author: Kevin D. Smith <Kevin.Smith@sixquickrun.com>                        #
# Author: Cody Precord <cprecord@editra.org>                                  #
# Copyright: (c) 2008 Cody Precord <staff@editra.org>                         #
# License: wxWindows License                                                  #
###############################################################################

"""
Source Control Commands

@author: Cody Precord <cprecord@editra.org>
@author: Kevin D. Smith <Kevin.Smith@sixquickrun.com>

"""

# Module metadata. __svnid__ and __revision__ hold Subversion keyword
# strings ($Id$ / $Revision$), so their text is left exactly as checked in.
__author__ = "Kevin D. Smith <Kevin.Smith@sixquickrun.com>"
__svnid__ = "$Id: ScCommand.py 1406 2011-06-06 23:53:08Z CodyPrecord@gmail.com $"
__revision__ = "$Revision: 1406 $"

#--------------------------------------------------------------------------#
# Imports
import time
import wx
import os
import threading
import shutil
import subprocess
import tempfile

# Editra Libraries
import ed_thread

# Local Imports
from projects.ConfigDialog import ConfigData
import projects.diffwin as diffwin

#--------------------------------------------------------------------------#
# Globals

# Error Codes
# Values carried in SourceControlEvent.Error:
#   SC_ERROR_NONE           - the job completed without error
#   SC_ERROR_RETRIEVAL_FAIL - file content could not be retrieved/compared
SC_ERROR_NONE, SC_ERROR_RETRIEVAL_FAIL = 0, 1

#--------------------------------------------------------------------------#
# Event Types

def _NewEventTypeAndBinder():
    """Allocate a new wx event type and create a PyEventBinder for it
    (with expectedIDs=1).
    @return: tuple (event_type, binder)

    """
    evt_type = wx.NewEventType()
    return evt_type, wx.PyEventBinder(evt_type, 1)

# Source Control Command has finished
ppEVT_CMD_COMPLETE, EVT_CMD_COMPLETE = _NewEventTypeAndBinder()

# A diff job has completed
ppEVT_DIFF_COMPLETE, EVT_DIFF_COMPLETE = _NewEventTypeAndBinder()

# A status job event
ppEVT_STATUS, EVT_STATUS = _NewEventTypeAndBinder()

class SourceControlEvent(wx.PyCommandEvent):
    """Command event used to report the outcome of a source control job.

    Value carries the job's payload and Error carries one of the
    SC_ERROR_* codes.

    """
    def __init__(self, etype, eid, value=None, err=SC_ERROR_NONE):
        """Create the event
        @param etype: event type
        @param eid: id passed to the PyCommandEvent base
        @keyword value: event payload
        @keyword err: error code (SC_ERROR_NONE by default)

        """
        super(SourceControlEvent, self).__init__(etype, eid)

        # Attributes
        self._err = err
        self._value = value

    def _GetError(self):
        """Return the stored error code"""
        return self._err

    def _SetError(self, err):
        """Replace the stored error code"""
        self._err = err

    def _GetValue(self):
        """Return the stored payload"""
        return self._value

    def _SetValue(self, value):
        """Replace the stored payload"""
        self._value = value

    Error = property(_GetError, _SetError)
    Value = property(_GetValue, _SetValue)

#--------------------------------------------------------------------------#

class ScCommandThread(threading.Thread):
    """Run a task in its own thread and post its result to a parent
    window/event handler as a SourceControlEvent.

    """
    def __init__(self, parent, task, etype, args=(), kwargs=None):
        """Initialize the ScCommandThread. All *args and **kwargs are passed
        to the task.

        @param parent: Parent Window/EventHandler to receive the events
                       generated by the process.
        @param task: callable to run. Should return a tuple (rval, errmsg);
                     any other return value is posted as the event value
                     with SC_ERROR_NONE as the error code.
        @param etype: callback event type to post
        @keyword args: positional arguments for task
        @keyword kwargs: keyword arguments for task. The mapping is copied,
                         so it is never shared between thread instances.

        """
        super(ScCommandThread, self).__init__()

        # Attributes
        self.cancel = False         # Abort task (not checked by run())
        self._parent = parent       # Parent Window/Event Handler
        self._pid = parent.GetId()  # Parent ID
        self.task = task            # Task method to run
        self.etype = etype          # Event type posted when task finishes
        self._args = args
        # Copy instead of storing a (previously shared, mutable) default
        self._kwargs = dict(kwargs or {})

    def run(self):
        """Run the task and post its result back to the parent window"""
        rval = self.task(*self._args, **self._kwargs)

        # Accept the documented (value, err) pair, but tolerate tasks that
        # return a bare value instead of dying without posting an event.
        if isinstance(rval, tuple) and len(rval) == 2:
            value, err = rval
        else:
            value, err = rval, SC_ERROR_NONE

        # Post Results back to parent window
        evt = SourceControlEvent(self.etype, self._pid, value, err)
        wx.PostEvent(self._parent, evt)

#--------------------------------------------------------------------------#

def SourceControlTask(parent, pid, task, etype, args=None, kwargs=None):
    """Job function for EdThreadPool: run task and post its outcome.

    A result that is a 2-tuple is treated as (value, err); any other
    result becomes the event value with SC_ERROR_NONE as the error code.

    @param parent: callback window that receives the event
    @param pid: parent window ID (used as the event id)
    @param task: callable
    @param etype: event type (for callback event)
    @keyword args: positional arguments for task (None means none)
    @keyword kwargs: keyword arguments for task (None means none)

    """
    assert callable(task)

    call_args = () if args is None else args
    call_kwargs = dict() if kwargs is None else kwargs
    result = task(*call_args, **call_kwargs)

    is_pair = isinstance(result, tuple) and len(result) == 2
    if is_pair:
        value, err = result
    else:
        value, err = result, SC_ERROR_NONE

    # Post Results back to parent window
    wx.PostEvent(parent, SourceControlEvent(etype, pid, value, err))

#--------------------------------------------------------------------------#

class SourceController(object):
    """Source control command controller

    Resolves which configured source control system manages a path and
    runs commands against it on background threads. Results are reported
    to the owner window with SourceControlEvent events.

    """
    # Shared by all instances: path -> key of the system in
    # ConfigData().getSCSystems() that was found to control that path.
    CACHE = dict()

    def __init__(self, owner):
        """Create the SourceController
        @param owner: Owner window (receives the posted events)

        """
        super(SourceController, self).__init__()

        # Attributes
        self._parent = owner
        self._pid = self._parent.GetId()
        self.config = ConfigData() # Singleton config data instance
        self.tempdir = None        # Created on first diff (fetched revisions)
        self.scThreads = {}        # Tracked worker threads

        # Number of seconds to allow a source control command to run
        # before timing out
        self.scTimeout = 60

    def __del__(self):
        """Remove temporary diff files and forget tracked threads"""
        # Attributes may be missing if __init__ did not complete
        tempdir = getattr(self, 'tempdir', None)
        if tempdir:
            shutil.rmtree(tempdir, ignore_errors=True)
        diffwin.CleanupTempFiles()

        # Forget tracked source control threads. The threads themselves are
        # not interrupted (ScCommand starts its threads as daemons).
        if hasattr(self, 'scThreads'):
            self.CleanupThreads(False)

    def _TimeoutCommand(self, callback, *args, **kwargs):
        """Run a command on a helper thread, but stop waiting for it after
        self.scTimeout seconds.
        @param callback: callable to call with the command's return value,
                         or None
        @param args: the callable to run, followed by its positional args
        @param kwargs: keyword arguments for the callable
        @return: False if the command timed out, True otherwise

        """
        result = []
        def resultWrapper(result, *args, **kwargs):
            """ Function to catch output of threaded method """
            args = list(args)
            method = args.pop(0)
            result.append(method(*args, **kwargs))

        # Insert result object to catch output
        args = list(args)
        args.insert(0, result)

        # Start thread
        t = threading.Thread(target=resultWrapper, args=args, kwargs=kwargs)
        t.start()

        # Track the thread and wait up till timeout for it to exit
        self.scThreads[t] = False
        t.join(self.scTimeout)

        if t.isAlive():
            # Timed out; the thread stays tracked so that CleanupThreads
            # can drop it once it has finished.
            return False

        self.scThreads.pop(t, None)

        # result is empty if the command raised inside the thread
        if callback is not None and result:
            callback(result[0])

        return True

    def CleanupThreads(self, deadonly=True):
        """Cleanup worker threads
        @keyword deadonly: only cleanup dead threads
        @return: number of threads that are still tracked

        """
        if not deadonly:
            self.scThreads.clear()
        else:
            # Snapshot the keys; never delete from a dict while iterating it
            dead = [t for t in list(self.scThreads) if not t.isAlive()]
            for t in dead:
                self.scThreads.pop(t, None)

        return len(self.scThreads)

    def CompareRevisions(self, path, rev1=None, date1=None, rev2=None, date2=None):
        """
        Compare the playpen path to a specific revision, or compare two
        revisions. The comparison runs as a job on the EdThreadPool and an
        EVT_DIFF_COMPLETE event is posted to the owner window when done.

        Required Arguments:
        path -- absolute path of file to compare

        Keyword Arguments:
        rev1/date1 -- first file revision/date to compare against
        rev2/date2 -- second file revision/date to compare against

        """
        ed_thread.EdThreadPool().QueueJob(SourceControlTask, self._parent,
                                          self._pid, self.Diff,
                                          ppEVT_DIFF_COMPLETE,
                                          args=(path, ),
                                          kwargs=dict(rev1=rev1, date1=date1,
                                                      rev2=rev2, date2=date2))

    def Diff(self, path, rev1, date1, rev2, date2):
        """ Do the actual diff of two files by sending the files
        to be compared to the appropriate diff program.

        If only one revision/date (or none) is given, the working copy at
        path is compared against that fetched revision. With no revision,
        the content returned by fetch([path]) is used (labelled 'previous').

        @param path: path of the file (or directory) to compare
        @param rev1/date1: first revision/date, or None
        @param rev2/date2: second revision/date, or None
        @return: tuple (None, err_code), or None when path is a directory
                 or is not under source control

        """
        # Only diff files. For a directory, queue a comparison job for each
        # file directly inside it; subdirectories (e.g. source control
        # metadata directories) are not descended into.
        if os.path.isdir(path):
            for fname in os.listdir(path):
                fpath = os.path.join(path, fname)
                if not os.path.isdir(fpath):
                    self.CompareRevisions(fpath, rev1=rev1, date1=date1,
                                          rev2=rev2, date2=date2)
            return None

        # Check if path is under source control
        sc = self.GetSCSystem(path)
        if sc is None:
            return None
        scinst = sc['instance']

        content1 = content2 = ext1 = ext2 = None

        # Grab the first specified revision
        if rev1 or date1:
            content1 = self._FirstFetched(scinst.fetch([path], rev=rev1,
                                                       date=date1))
            if content1 is None:
                return (None, SC_ERROR_RETRIEVAL_FAIL)
            ext1 = rev1 or date1

        # Grab the second specified revision
        if rev2 or date2:
            content2 = self._FirstFetched(scinst.fetch([path], rev=rev2,
                                                       date=date2))
            if content2 is None:
                return (None, SC_ERROR_RETRIEVAL_FAIL)
            ext2 = rev2 or date2

        # No revision requested: use the default content from fetch()
        if not (rev1 or date1 or rev2 or date2):
            content1 = self._FirstFetched(scinst.fetch([path]))
            if content1 is None:
                return (None, SC_ERROR_RETRIEVAL_FAIL)
            ext1 = 'previous'

        if not self.tempdir:
            self.tempdir = tempfile.mkdtemp()

        # Write temporary files. Compare against None rather than truth
        # value so that an empty file revision ('') is still diffed.
        if content1 is not None and content2 is not None:
            path1 = self._WriteTempRevision(path, ext1, content1)
            path2 = self._WriteTempRevision(path, ext2, content2)
        elif content1 is not None:
            path1 = path
            path2 = self._WriteTempRevision(path, ext1, content1)
        elif content2 is not None:
            path1 = path
            path2 = self._WriteTempRevision(path, ext2, content2)
        else:
            # If all content retrieval failed exit with error
            return (None, SC_ERROR_RETRIEVAL_FAIL)

        # Run comparison program
        if self.config.getBuiltinDiff() or not self.config.getDiffProgram():
            diffwin.GenerateDiff(path2, path1, html=True)
        elif isinstance(path1, basestring) and isinstance(path2, basestring):
            subprocess.call([self.config.getDiffProgram(), path2, path1])
        else:
            return (None, SC_ERROR_RETRIEVAL_FAIL)

        return (None, SC_ERROR_NONE)

    @staticmethod
    def _FirstFetched(fetched):
        """Extract the content from the result of a single-path fetch()
        @param fetched: value returned by fetch([path], ...)
        @return: the fetched content, or None if retrieval failed (the
                 result was None/empty or its first entry was None)

        """
        if not fetched:
            return None
        return fetched[0]

    def _WriteTempRevision(self, path, ext, content):
        """Write content to '<tempdir>/<basename(path)>.<ext>'
        @return: path of the written file

        """
        base = os.path.join(self.tempdir, os.path.basename(path))
        tpath = '%s.%s' % (base, ext)
        tfile = open(tpath, 'w')
        try:
            tfile.write(content)
        finally:
            tfile.close()
        return tpath

    def GetSCSystem(self, path):
        """ Determine source control system being used on path if any
        @param path: file or directory path
        @return: the matching entry of config.getSCSystems() (a dict with
                 an 'instance' key), or None if path is not controlled

        """
        # XXX: Experimental caching of paths to speed up commands.
        #      Currently the improvements are quite measurable, need
        #      to monitor memory usage and end cases though.
        systems = self.config.getSCSystems()
        cache = SourceController.CACHE
        key = cache.get(path)
        if key is not None:
            if key in systems:
                return systems[key]
            # The cached system is no longer configured; detect it again
            cache.pop(path, None)

        for key, value in systems.items():
            if value['instance'].isControlled(path):
                cache[path] = key
                return value
        return None

    def IsSingleRepository(self, paths):
        """
        Are all paths from the same repository ?

        Required Arguments:
        paths -- list of paths to test

        Returns: boolean indicating if all paths are in the same repository
            (True), or if they are not (False). Paths whose repository
            cannot be determined are ignored.

        """
        previous = ''
        for path in paths:
            try:
                sc = self.GetSCSystem(path)
                if sc is None:
                    continue
                reppath = sc['instance'].getRepository(path)
            except Exception:
                # Best effort: skip paths whose repository can't be found
                continue

            if not previous:
                previous = reppath
            elif previous != reppath:
                return False
        return True

    def ScCommand(self, nodes, command, callback=None, **options):
        """
        Run a source control command on a daemon thread. An
        EVT_CMD_COMPLETE event is posted to the owner window when done.

        Required Arguments:
        nodes -- selected tree nodes [(treeitem, dict(path='', watcher=thread)]
        command -- name of command type to run

        """
        cjob = ScCommandThread(self._parent, self.RunScCommand,
                               ppEVT_CMD_COMPLETE,
                               args=(nodes, command, callback),
                               kwargs=options)
        cjob.setDaemon(True)
        cjob.start()
        self.scThreads[cjob] = False

    def RunScCommand(self, nodes, command, callback, **options):
        """Does the running of the command
        @param nodes: list [(node, data), (node2, data2), ...]
        @param command: command string (name of the source control method)
        @param callback: callable or None
        @return: (command, None), or (None, None) if no node is under
                 source control or a node stayed locked past scTimeout

        """
        concurrentcmds = ['status', 'history']
        DATA, SC = 1, 2
        nodeinfo = []   # entries are [node, data, sc]
        for node, data in nodes:
            # See if the node already has an operation running; poll once
            # a second for up to scTimeout seconds.
            waited = 0
            while data.get('sclock', None):
                time.sleep(1)
                waited += 1
                if waited > self.scTimeout:
                    return (None, None)

            # See if the node has a path associated
            # Technically, all nodes should (except the root node)
            if 'path' not in data:
                continue

            # Determine source control system. Directories and files being
            # added fall back to the system of their parent directory.
            nodesc = self.GetSCSystem(data['path'])
            if nodesc is None and \
               (os.path.isdir(data['path']) or command == 'add'):
                nodesc = self.GetSCSystem(os.path.dirname(data['path']))
            if nodesc is None:
                continue

            nodeinfo.append([node, data, nodesc])

        # Proceed if any node was resolved (not only the last one examined)
        if not nodeinfo:
            return (None, None)

        # The command is dispatched through the last resolved node's system
        sc = nodeinfo[-1][SC]

        # Lock node while command is running
        if command not in concurrentcmds:
            for info in nodeinfo:
                info[DATA]['sclock'] = command

        rc = True
        try:
            # Find correct method. 'status' is not run here; it is run per
            # node in the finally clause below.
            method = getattr(sc['instance'], command, None)
            if method is not None and command != 'status':
                if 'outhook' in options:
                    sc['instance'].setOutputHook(options.pop('outhook'))

                rc = self._TimeoutCommand(callback, method,
                                          [info[DATA]['path'] for info in nodeinfo],
                                          **options)

                # Make sure the output hook has been cleared
                sc['instance'].clearOutputHook()
        finally:
            # Only update status if last command didn't time out
            if command not in ('history', 'revert', 'update') and rc:
                for node, data, nodesc in nodeinfo:
                    self.StatusWithTimeout(nodesc, node, data)

            # Unlock
            if command not in concurrentcmds:
                for info in nodeinfo:
                    info[DATA].pop('sclock', None)

        return (command, None)

    def StatusWithTimeout(self, sc, node, data, recursive=False):
        """Run a SourceControl status command with a timeout and post an
        EVT_STATUS event to the owner window.
        @param sc: source control system entry (dict with an 'instance' key)
        @param node: tree node
        @param data: data dict(path='')
        @keyword recursive: passed through to the status command

        The event Value is (node, data, status, sc), where status is the dict
        handed to the status command as its status= argument (presumably
        filled in by it; may be empty if the command failed or timed out).

        """
        status = {}
        try:
            self._TimeoutCommand(None, sc['instance'].status,
                                 [data['path']],
                                 recursive=recursive,
                                 status=status)
        except Exception:
            # TODO: needs logging
            import sys
            print("ERROR: %s" % (sys.exc_info()[1],))

        evt = SourceControlEvent(ppEVT_STATUS, self._pid,
                                 (node, data, status, sc))
        wx.PostEvent(self._parent, evt)
